# -*- coding: utf-8 -*-
"""
05_recreate_multiclass_data.py

This script recreates the multi-class event predictions from the raw GTD Excel file.
It runs both Hugging Face and Ollama models and saves the merged output to a
single CSV file that is used by the analysis scripts.
"""
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
import json
import os
import glob
from tqdm import tqdm

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Input workbook and output locations.
RAW_EXCEL_PATH = "data/raw/globalterrorismdb_0522dist.xlsx"
OUTPUT_DIR = "data/model_outputs"
FINAL_MERGED_FILE = os.path.join(OUTPUT_DIR, 'raw_gtd_multilabel_data.csv')

# Events must be dated strictly after this day to be kept.
CUTOFF_DATE = '2017-01-01'

# Model identifier -> True if it is served through Ollama, False if it is
# loaded with transformers. Dict order is the order the models are run in.
# (The gtd-2018 checkpoint was deliberately removed from this list.)
MODELS_TO_RUN = dict([
    ('eventdata-utd/gtd-2016', False),
    ('llama3.1:latest', True),
    ('gemma2:9b', True),
    ('qwen2.5:14b', True),
])

# Classifier output index -> GTD attack-type name. Assumed to match the
# label order of the Hugging Face checkpoint above -- TODO confirm against
# the model config's id2label.
CONFLIBERT_LABELS = dict(enumerate([
    "Assassination",
    "Armed Assault",
    "Bombing/Explosion",
    "Hijacking",
    "Hostage Taking (Barricade Incident)",
    "Hostage Taking (Kidnapping)",
    "Facility/Infrastructure Attack",
    "Unarmed Assault",
    "Unknown",
]))

# Prompt sent to Ollama models. It is filled with str.format(text=...), so the
# literal braces in the JSON example are doubled.
OLLAMA_PROMPT_TEMPLATE = """Classify the following event into up to three of these categories, providing probabilities for each:
Assassination, Armed Assault, Bombing/Explosion, Hijacking,
Hostage Taking (Barricade Incident), Hostage Taking (Kidnapping),
Facility/Infrastructure Attack, Unarmed Assault, Unknown

For the event, return only a single JSON object with category names as keys and probabilities as values.
Example format: {{"Armed Assault": 0.7, "Bombing/Explosion": 0.2, "Unknown": 0.1}}

Event:
"{text}"
"""

# --- Helper Functions ---
def setup_device():
    """Choose the torch device for inference.

    Apple MPS is preferred, then CUDA; CPU is the fallback when neither
    accelerator reports itself as available.
    """
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_type, available in candidates:
        if available():
            return torch.device(device_type)
    return torch.device("cpu")

def load_hf_model(model_name, device):
    """Load a sequence-classification model and its matching tokenizer.

    Both are created from ``model_name``; the model is moved onto ``device``
    before the tokenizer is loaded.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    logging.info(f"Loading Hugging Face model {model_name} to {device}...")
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    classifier = classifier.to(device)
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return classifier, text_tokenizer

def run_hf_inference(texts, model, tokenizer, device, *, batch_size=32, max_length=512):
    """Classify ``texts`` with a Hugging Face sequence-classification model.

    For every text, a softmax is taken over the model logits. The three most
    probable classes are serialised as a JSON object that maps the label name
    (looked up in ``CONFLIBERT_LABELS``) to its probability.

    Args:
        texts: Sliceable sequence of inputs; each item is converted with ``str()``.
        model: Sequence-classification model, already placed on ``device``.
        tokenizer: Tokenizer matching ``model``.
        device: Torch device the tokenized batch is moved to.
        batch_size: Number of texts per forward pass (keyword-only, >= 1).
        max_length: Token limit per text; longer inputs are truncated.

    Returns:
        A list of JSON strings with exactly one entry per input text, in input
        order. If a batch raises, the error is logged and every text in that
        batch gets ``{"Unknown": 1.0}``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    fallback = json.dumps({"Unknown": 1.0})
    results = []
    for start in tqdm(range(0, len(texts), batch_size), desc=f"Processing {model.name_or_path}"):
        batch = [str(text) for text in texts[start:start + batch_size]]
        try:
            inputs = tokenizer(
                batch, return_tensors='pt', truncation=True, padding=True, max_length=max_length
            ).to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()
            # Build the whole batch locally and only commit it on success. A
            # failure part-way through (e.g. a class index that is missing from
            # CONFLIBERT_LABELS) then cannot leave a partial batch in
            # ``results``, which would shift every later prediction away from
            # its text.
            batch_results = []
            for row in probs:
                top3 = np.argsort(row)[::-1][:3]
                batch_results.append(
                    json.dumps({CONFLIBERT_LABELS[k]: float(row[k]) for k in top3})
                )
        except Exception as e:
            logging.error(f"Error in HF batch {start // batch_size}: {e}")
            batch_results = [fallback] * len(batch)
        results.extend(batch_results)
    return results

def run_ollama_inference(texts, model_name):
    """Classify each text by prompting an Ollama-served chat model.

    One ``ollama.chat`` request is sent per text, with JSON output mode
    requested. The reply content is returned as-is, without parsing.

    Returns:
        A list of strings, one per input text, in input order. If a request
        raises, the error is logged and ``'{"Unknown": 1.0}'`` is used instead.
    """
    import ollama  # deferred so runs without Ollama models do not need it

    def classify_one(raw_text):
        try:
            prompt = OLLAMA_PROMPT_TEMPLATE.format(text=str(raw_text))
            reply = ollama.chat(
                model=model_name,
                messages=[{'role': 'user', 'content': prompt}],
                format='json',
            )
            return reply['message']['content']
        except Exception as err:
            logging.error(f"Error with Ollama model {model_name}: {err}")
            return json.dumps({"Unknown": 1.0})

    return [classify_one(item) for item in tqdm(texts, desc=f"Processing {model_name}")]

# --- Main Execution ---
def _prepare_events(raw_df, cutoff_date):
    """Return the filtered event frame with added ``text`` and ``date`` columns.

    ``text`` is ``summary`` and ``motive`` joined by a space, with missing
    values treated as empty, then stripped. ``date`` is assembled from
    ``iyear``/``imonth``/``iday``; any combination that is not a valid
    calendar date (e.g. a month or day of 0) becomes NaT. Rows with empty
    text or an invalid date are dropped. Only rows dated strictly after
    ``cutoff_date`` are kept, and the index is reset to 0..n-1.
    The input frame is not modified.
    """
    text = (raw_df['summary'].fillna('') + ' ' + raw_df['motive'].fillna('')).str.strip()
    # Vectorised assembly from numeric columns, instead of formatting and
    # re-parsing one string per row.
    date = pd.to_datetime(
        raw_df[['iyear', 'imonth', 'iday']].rename(
            columns={'iyear': 'year', 'imonth': 'month', 'iday': 'day'}
        ),
        errors='coerce',
    )
    events = raw_df.assign(text=text, date=date).dropna(subset=['text', 'date'])
    events = events[events['text'] != '']
    return events[events['date'] > pd.to_datetime(cutoff_date)].reset_index(drop=True)


def main():
    """Rebuild the merged multi-label prediction CSV from the raw GTD workbook.

    The raw workbook is loaded and filtered, then every model in
    MODELS_TO_RUN is run in dict order. Each model adds a
    ``<model_name>_predictions`` column, and the resulting frame is written
    to FINAL_MERGED_FILE.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Fail fast: run_ollama_inference imports ollama only when it is called.
    # That can happen after earlier Hugging Face models have already run, and
    # nothing is written to disk until all models finish, so a missing
    # package would otherwise throw that work away.
    if any(MODELS_TO_RUN.values()):
        import ollama  # noqa: F401

    logging.info(f"Loading and processing data from {RAW_EXCEL_PATH}...")
    # Pass the raw frame straight through so it is not kept alive for the
    # whole run.
    df = _prepare_events(pd.read_excel(RAW_EXCEL_PATH, engine='openpyxl'), CUTOFF_DATE)
    logging.info(f"Filtered to {len(df)} rows.")

    texts_to_process = df['text'].tolist()
    device = setup_device()

    # Run inference and add one prediction column per model.
    for model_name, is_ollama in MODELS_TO_RUN.items():
        logging.info(f"\n--- Processing with {model_name} ---")
        if is_ollama:
            predictions = run_ollama_inference(texts_to_process, model_name)
        else:
            model, tokenizer = load_hf_model(model_name, device)
            predictions = run_hf_inference(texts_to_process, model, tokenizer, device)
            # Drop our references first so the allocator can actually release
            # the model's cached memory.
            del model, tokenizer
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            elif device.type == 'mps':
                torch.mps.empty_cache()

        df[f'{model_name}_predictions'] = predictions

    logging.info(f"Saving final merged results to {FINAL_MERGED_FILE}")
    df.to_csv(FINAL_MERGED_FILE, index=False)
    logging.info("Multi-class data recreation complete.")


if __name__ == "__main__":
    main()